43c1ce
@@ -21,7 +21,6 @@
import java.security.Principal;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.concurrent.Callable;
-
 import javax.servlet.ServletContext;
 import javax.servlet.ServletException;
 import javax.servlet.http.HttpServletRequest;
@@ -52,7 +51,6 @@
import org.springframework.web.context.request.NativeWebRequest;
 import org.springframework.web.context.request.RequestAttributes;
 import org.springframework.web.context.request.RequestContextHolder;
 import org.springframework.web.context.request.ServletRequestAttributes;
-import org.springframework.web.context.request.async.CallableProcessingInterceptor;
 import org.springframework.web.context.request.async.CallableProcessingInterceptorAdapter;
 import org.springframework.web.context.request.async.WebAsyncManager;
 import org.springframework.web.context.request.async.WebAsyncUtils;
@@ -785,6 +783,18 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		// For subclasses: do nothing by default.
 	}
 
+	/**
+	 * Log to the ServletContext, then close the WebApplicationContext if it is a ConfigurableApplicationContext.
+	 * @see org.springframework.context.ConfigurableApplicationContext#close()
+	 */
+	@Override
+	public void destroy() {
+		getServletContext().log("Destroying Spring FrameworkServlet '" + getServletName() + "'");
+		if (this.webApplicationContext instanceof ConfigurableApplicationContext) {
+			((ConfigurableApplicationContext) this.webApplicationContext).close();
+		}
+	}
+
 
 	/**
 	 * Override the parent class implementation in order to intercept PATCH
@@ -873,7 +883,7 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		super.doOptions(request, new HttpServletResponseWrapper(response) {
 			@Override
 			public void setHeader(String name, String value) {
-				if("Allow".equals(name)) {
+				if ("Allow".equals(name)) {
 					value = (StringUtils.hasLength(value) ? value + ", " : "") + RequestMethod.PATCH.name();
 				}
 				super.setHeader(name, value);
@@ -915,15 +925,12 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		LocaleContext localeContext = buildLocaleContext(request);
 
 		RequestAttributes previousAttributes = RequestContextHolder.getRequestAttributes();
-		ServletRequestAttributes requestAttributes = null;
-		if (previousAttributes == null || (previousAttributes instanceof ServletRequestAttributes)) {
-			requestAttributes = new ServletRequestAttributes(request);
-		}
-
-		initContextHolders(request, localeContext, requestAttributes);
+		ServletRequestAttributes requestAttributes = buildRequestAttributes(request, response, previousAttributes);
 
 		WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request);
-		asyncManager.registerCallableInterceptor(FrameworkServlet.class.getName(), getRequestBindingInterceptor(request));
+		asyncManager.registerCallableInterceptor(FrameworkServlet.class.getName(), new RequestBindingInterceptor());
+
+		initContextHolders(request, localeContext, requestAttributes);
 
 		try {
 			doService(request, response);
@@ -950,27 +957,18 @@
public abstract class FrameworkServlet extends HttpServletBean {
 			if (logger.isDebugEnabled()) {
 				if (failureCause != null) {
 					this.logger.debug("Could not complete request", failureCause);
-				} else {
+				}
+				else {
 					if (asyncManager.isConcurrentHandlingStarted()) {
-						if (logger.isDebugEnabled()) {
-							logger.debug("Leaving response open for concurrent processing");
-						}
+						logger.debug("Leaving response open for concurrent processing");
 					}
 					else {
 						this.logger.debug("Successfully completed request");
 					}
 				}
 			}
-			if (this.publishEvents) {
-				// Whether or not we succeeded, publish an event.
-				long processingTime = System.currentTimeMillis() - startTime;
-				this.webApplicationContext.publishEvent(
-						new ServletRequestHandledEvent(this,
-								request.getRequestURI(), request.getRemoteAddr(),
-								request.getMethod(), getServletConfig().getServletName(),
-								WebUtils.getSessionId(request), getUsernameForRequest(request),
-								processingTime, failureCause));
-			}
+
+			publishRequestHandledEvent(request, startTime, failureCause);
 		}
 	}
 
@@ -978,18 +976,43 @@
public abstract class FrameworkServlet extends HttpServletBean {
 	 * Build a LocaleContext for the given request, exposing the request's
 	 * primary locale as current locale.
 	 * @param request current HTTP request
-	 * @return the corresponding LocaleContext
+	 * @return the corresponding LocaleContext, or {@code null} if none to bind
+	 * @see LocaleContextHolder#setLocaleContext
 	 */
 	protected LocaleContext buildLocaleContext(HttpServletRequest request) {
 		return new SimpleLocaleContext(request.getLocale());
 	}
 
-	private void initContextHolders(HttpServletRequest request,
-			LocaleContext localeContext, RequestAttributes attributes) {
+	/**
+	 * Build ServletRequestAttributes for the given request, taking pre-bound
+	 * attributes (and their type) into consideration. The default implementation
+	 * ignores the response; subclasses may override this method to make use of it.
+	 * @param request current HTTP request
+	 * @param response current HTTP response
+	 * @param previousAttributes pre-bound RequestAttributes instance, if any
+	 * @return the ServletRequestAttributes to bind, or {@code null} to preserve
+	 * the previously bound instance (or not binding any, if none bound before)
+	 * @see RequestContextHolder#setRequestAttributes
+	 */
+	protected ServletRequestAttributes buildRequestAttributes(
+			HttpServletRequest request, HttpServletResponse response, RequestAttributes previousAttributes) {
+
+		if (previousAttributes == null || previousAttributes instanceof ServletRequestAttributes) {
+			return new ServletRequestAttributes(request);
+		}
+		else {
+			return null;  // preserve the pre-bound RequestAttributes instance
+		}
+	}
+
+	private void initContextHolders(
+			HttpServletRequest request, LocaleContext localeContext, RequestAttributes requestAttributes) {
 
-		LocaleContextHolder.setLocaleContext(localeContext, this.threadContextInheritable);
-		if (attributes != null) {
-			RequestContextHolder.setRequestAttributes(attributes, this.threadContextInheritable);
+		if (localeContext != null) {
+			LocaleContextHolder.setLocaleContext(localeContext, this.threadContextInheritable);
+		}
+		if (requestAttributes != null) {
+			RequestContextHolder.setRequestAttributes(requestAttributes, this.threadContextInheritable);
 		}
 		if (logger.isTraceEnabled()) {
 			logger.trace("Bound request context to thread: " + request);
@@ -1006,17 +1029,17 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		}
 	}
 
-	private CallableProcessingInterceptor getRequestBindingInterceptor(final HttpServletRequest request) {
-		return new CallableProcessingInterceptorAdapter() {
-			@Override
-			public <T> void preProcess(NativeWebRequest webRequest, Callable<T> task) {
-				initContextHolders(request, buildLocaleContext(request), new ServletRequestAttributes(request));
-			}
-			@Override
-			public <T> void postProcess(NativeWebRequest webRequest, Callable<T> task, Object concurrentResult) {
-				resetContextHolders(request, null, null);
-			}
-		};
+	private void publishRequestHandledEvent(HttpServletRequest request, long startTime, Throwable failureCause) {
+		if (this.publishEvents) {
+			// Publish whether or not the request succeeded; a null failureCause means no failure was recorded.
+			long processingTime = System.currentTimeMillis() - startTime;
+			this.webApplicationContext.publishEvent(
+					new ServletRequestHandledEvent(this,
+							request.getRequestURI(), request.getRemoteAddr(),
+							request.getMethod(), getServletConfig().getServletName(),
+							WebUtils.getSessionId(request), getUsernameForRequest(request),
+							processingTime, failureCause));
+		}
 
 	/**
@@ -1032,6 +1055,7 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		return (userPrincipal != null ? userPrincipal.getName() : null);
 	}
 
+
 	/**
 	 * Subclasses must implement this method to do the work of request handling,
 	 * receiving a centralized callback for GET, POST, PUT and DELETE.
@@ -1049,19 +1073,6 @@
public abstract class FrameworkServlet extends HttpServletBean {
 			throws Exception;
 
 
-	/**
-	 * Close the WebApplicationContext of this servlet.
-	 * @see org.springframework.context.ConfigurableApplicationContext#close()
-	 */
-	@Override
-	public void destroy() {
-		getServletContext().log("Destroying Spring FrameworkServlet '" + getServletName() + "'");
-		if (this.webApplicationContext instanceof ConfigurableApplicationContext) {
-			((ConfigurableApplicationContext) this.webApplicationContext).close();
-		}
-	}
-
-
 	/**
 	 * ApplicationListener endpoint that receives events from this servlet's WebApplicationContext
 	 * only, delegating to {@code onApplicationEvent} on the FrameworkServlet instance.
@@ -1073,4 +1084,40 @@
public abstract class FrameworkServlet extends HttpServletBean {
 		}
 	}
 
+
+	/**
+	 * CallableProcessingInterceptor implementation that initializes and resets
+	 * FrameworkServlet's context holders, i.e. LocaleContextHolder and RequestContextHolder,
+	 * around the processing of an async {@link Callable}.
+	 */
+	private class RequestBindingInterceptor extends CallableProcessingInterceptorAdapter {
+
+		/**
+		 * Bind locale context and request attributes for the native servlet request
+		 * (and response, if available); does nothing if the native request is not an HttpServletRequest.
+		 */
+		@Override
+		public <T> void preProcess(NativeWebRequest webRequest, Callable<T> task) {
+			HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);
+			if (request != null) {
+				// The response must be obtained via getNativeResponse: getNativeRequest
+				// inspects the native *request* object and would never match this type.
+				HttpServletResponse response = webRequest.getNativeResponse(HttpServletResponse.class);
+				initContextHolders(request, buildLocaleContext(request), buildRequestAttributes(request, response, null));
+			}
+		}
+
+		/**
+		 * Reset the context holders, passing {@code null} for the previous locale context and
+		 * request attributes; does nothing if the native request is not an HttpServletRequest.
+		 */
+		@Override
+		public <T> void postProcess(NativeWebRequest webRequest, Callable<T> task, Object concurrentResult) {
+			HttpServletRequest request = webRequest.getNativeRequest(HttpServletRequest.class);
+			if (request != null) {
+				resetContextHolders(request, null, null);
+			}
+		}
+	}
+
 }
